{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with TopoTune"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we go over the basic workings of [TopoTune](https://arxiv.org/pdf/2410.06530), a comprehensive framework for easily defining and training new, general topological deep learning (TDL) models on any domain. These models, called Generalized Combinatorial Complex Neural Networks (GCCNs), are built using any (graph) neural network, which we will denote ω. \n",
    "\n",
    "In a GCCN (pictured below), the input complex--whether it be a hypergraph, cell complex, simplicial complex, or combinatorial complex--is represented as an ensemble of graphs (specifically, strictly augmented Hasse graphs), one per neighborhood of the complex. Each of these Hasse graphs is processed by a sub-model GNN (ω), and the outputs are aggregated rank-wise between layers. \n",
    "\n",
    "![gccn](https://github.com/user-attachments/assets/97747900-8e5e-401c-9ad9-764e16e1698e)\n",
    "**Generalized Combinatorial Complex Network (GCCN).** In this example, the input complex $\\mathcal{C}$ has neighborhoods $\\mathcal{N_C}$ = { $\\mathcal{N_1}$ , $\\mathcal{N_2}$, $\\mathcal{N_3}$ }. **A.** The complex is expanded into three augmented Hasse graphs $\\mathcal{G_\\mathcal{N_i}}$ , $i=\\{1,2,3\\}$, each with features $H_\\mathcal{N_i}$ represented as a colored disc. **B.** A GCCN layer dedicates one base architecture $\\omega_\\mathcal{N_i}$ to each neighborhood. **C.** The outputs of all the architectures $\\omega_\\mathcal{N_i}$ are aggregated rank-wise, then updated. In this example, only the complex's edge features (originally pink) are aggregated across multiple neighborhoods ($\\mathcal{N_2}$ and $\\mathcal{N_3}$)."
   ]
  },
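  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the rank-wise aggregation step described above, written in plain PyTorch. It is not TopoTune's implementation: the neighborhood names, the tensor shapes, and the choice of summation as the aggregation are illustrative assumptions. Here $\\mathcal{N_1}$ is assumed to update node features (rank 0), while $\\mathcal{N_2}$ and $\\mathcal{N_3}$ both update edge features (rank 1), as in the figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical outputs of the sub-models, one entry per neighborhood:\n",
    "# (rank of the cells being updated, features for all cells of that rank).\n",
    "num_nodes, num_edges, dim = 4, 5, 8  # illustrative sizes\n",
    "sub_model_outputs = {\n",
    "    \"N1\": (0, torch.ones(num_nodes, dim)),\n",
    "    \"N2\": (1, torch.ones(num_edges, dim)),\n",
    "    \"N3\": (1, 2 * torch.ones(num_edges, dim)),\n",
    "}\n",
    "\n",
    "\n",
    "def aggregate_rankwise(outputs_by_neighborhood):\n",
    "    \"\"\"Sum the outputs of all neighborhoods that update the same rank.\"\"\"\n",
    "    per_rank = {}\n",
    "    for rank, features in outputs_by_neighborhood.values():\n",
    "        if rank in per_rank:\n",
    "            per_rank[rank] = per_rank[rank] + features\n",
    "        else:\n",
    "            per_rank[rank] = features\n",
    "    return per_rank\n",
    "\n",
    "\n",
    "aggregated = aggregate_rankwise(sub_model_outputs)\n",
    "for rank, features in sorted(aggregated.items()):\n",
    "    print(f\"rank {rank}: shape {tuple(features.shape)}\")"
   ]
  },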
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table of contents <a class='anchor' id='top'></a>\n",
    "We will go over **three example cases** of using TopoTune for defining and training GCCNs.\n",
    "\n",
    "&emsp;[- Use Case A:](#sec2) A GCCN using a GNN available by import (GAT imported from PyG) and a dataset either available in TopoBench or available in a PyG-like format.\n",
    "\n",
    "&emsp;[- Use Case B:](#sec3) A GCCN using a custom neural network.\n",
    "\n",
    "&emsp;[- Use Case C:](#sec4) Running large-scale experiments that consider many different versions of Use Case A, as is done in the [TopoTune](https://arxiv.org/pdf/2410.06530) paper.\n",
    "\n",
    "In all of these cases, you are encouraged to try different options and exploit the flexibility of TopoTune. This could mean trying different combinations of neighborhoods, different sub-models, different architecture choices, different training schemes, or different datasets. The purpose of this notebook is to allow for such exploration without requiring deeper knowledge of TopoBench.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports <a class=\"anchor\" id=\"sec1\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import lightning as pl\n",
    "# Hydra related imports\n",
    "from omegaconf import OmegaConf\n",
    "# Data related imports\n",
    "from topobench.data.loaders.graph import TUDatasetLoader\n",
    "from topobench.dataloader.dataloader import TBDataloader\n",
    "from topobench.data.preprocessor import PreProcessor\n",
    "# Model related imports\n",
    "from topobench.model.model import TBModel\n",
    "from topomodelx.nn.simplicial.scn2 import SCN2\n",
    "from topobench.nn.wrappers.simplicial import SCNWrapper\n",
    "from topobench.nn.encoders import AllCellFeatureEncoder\n",
    "from topobench.nn.readouts import PropagateSignalDown\n",
    "from topobench.nn.backbones.combinatorial.gccn import TopoTune\n",
    "from topobench.nn.wrappers.combinatorial import TuneWrapper\n",
    "from torch_geometric.nn import GAT\n",
    "# Optimization related imports\n",
    "from topobench.loss.loss import TBLoss\n",
    "from topobench.optimizer import TBOptimizer\n",
    "from topobench.evaluator.evaluator import TBEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Use Case A:** GCCN with imported GNN and dataset available in TopoBench <a class=\"anchor\" id=\"sec2\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we will define and train a GCCN using a GNN that is readily available in an imported package like `PyTorch Geometric` or `Deep Graph Library`. We will train and test the model with one of the many datasets available in TopoBench. \n",
    "\n",
    "*Step 1 :* Define the choice of neighborhoods to be considered. To specify a set of neighborhoods on the complex, use a list of neighborhoods each specified as a string of the form \n",
    "`r-{neighborhood}-k`, where $k$ represents the source cell rank, and $r$ is the number of ranks up or down that the selected `{neighborhood}` considers. Currently, the following options for `{neighborhood}` are supported:\n",
    "- `up_laplacian`, between cells of rank $k$ through $k+r$ cells.\n",
    "- `down_laplacian`, between cells of rank $k$ through $k-r$ cells.\n",
    "- `hodge_laplacian`, between cells of rank $k$ through both $k-r$ and $k+r$ cells.\n",
    "- `up_adjacency`, between cells of rank $k$ through $k+r$ cells.\n",
    "- `down_adjacency`, between cells of rank $k$ through $k-r$ cells.\n",
    "- `up_incidence`, from rank $k$ to $k+r$.\n",
    "- `down_incidence`, from rank $k$ to $k-r$.\n",
    "\n",
    "The number $r$ can be omitted, in which case $r=1$ by default (e.g. `up_incidence-k` represents the incidence from rank $k$ to $k+1$).\n",
    "Here are some examples of neighborhoods in this string notation:\n",
    "\n",
    "- node to node, through edges (up-Laplacian): `up_laplacian-0`\n",
    "- node to node, through faces (up-Laplacian): `2-up_laplacian-0`\n",
    "- edge to node (boundary, also called incidence): `down_incidence-1`\n",
    "- face to edge (boundary): `down_incidence-2`\n",
    "- face to node (boundary): `2-down_incidence-2`"
   ]
  },
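  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The string notation above can be unpacked with a few lines of standard-library Python. The helper below, `parse_neighborhood`, is a hypothetical illustration and not part of TopoBench; it only mirrors the `r-{neighborhood}-k` convention and the list of supported names given above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Neighborhood names listed above; this set and the helper are illustrative only.\n",
    "SUPPORTED_NEIGHBORHOODS = {\n",
    "    \"up_laplacian\", \"down_laplacian\", \"hodge_laplacian\",\n",
    "    \"up_adjacency\", \"down_adjacency\", \"up_incidence\", \"down_incidence\",\n",
    "}\n",
    "\n",
    "\n",
    "def parse_neighborhood(spec):\n",
    "    \"\"\"Split 'r-{neighborhood}-k' into (r, name, k); r defaults to 1 if omitted.\"\"\"\n",
    "    match = re.fullmatch(r\"(?:(\\d+)-)?([a-z_]+)-(\\d+)\", spec)\n",
    "    if match is None or match.group(2) not in SUPPORTED_NEIGHBORHOODS:\n",
    "        raise ValueError(f\"Unrecognized neighborhood string: {spec!r}\")\n",
    "    r = int(match.group(1)) if match.group(1) else 1\n",
    "    return r, match.group(2), int(match.group(3))\n",
    "\n",
    "\n",
    "for spec in [\"up_laplacian-0\", \"2-up_laplacian-0\", \"down_incidence-1\", \"2-down_incidence-2\"]:\n",
    "    print(spec, \"->\", parse_neighborhood(spec))"
   ]
  },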
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "neighborhoods = [\"1-up_laplacian-0\", \"1-down_incidence-1\", \"1-down_incidence-2\"]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define the model channels, choice of dataset (see all of TopoBench's readily available options [here]( https://www.youtube.com/embed/AR_3yRjFHYQ?si=_FD8x4S-wvpTbeLN)), choice of lifting (i.e. the choice of topological domain, see all options [here]( https://www.youtube.com/embed/AR_3yRjFHYQ?si=_FD8x4S-wvpTbeLN)), dataset split, and training scheme (readout, loss, evaluator, optimizer). Remark: when we run TopoBench from the command line, we rely on the YAML files stored in `configs` to specify these choices (see Use Case C for command line examples)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 7\n",
    "out_channels = 2\n",
    "dim_hidden = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_config = {\n",
    "    \"data_domain\": \"graph\", # the domain our dataset is originally in\n",
    "    \"data_type\": \"TUDataset\",\n",
    "    \"data_name\": \"MUTAG\",\n",
    "    \"data_dir\": \"./data/MUTAG/\",\n",
    "    }\n",
    "\n",
    "transform_config = { \"cycle_lifting\": # see TopoBench docs for full list of available liftings \n",
    "    {\"transform_type\": \"lifting\",\n",
    "    \"transform_name\": \"CellCycleLifting\", # determines the domain our dataset will be in (i.e. cell complex)\n",
    "    \"neighborhoods\": neighborhoods,\n",
    "\n",
    "    }\n",
    "}\n",
    "\n",
    "split_config = {\n",
    "    \"learning_setting\": \"inductive\",\n",
    "    \"split_type\": \"random\",\n",
    "    \"data_seed\": 0,\n",
    "    \"data_split_dir\": \"./data/MUTAG/splits/\", # use name of dataset\n",
    "    \"train_prop\": 0.5,\n",
    "}\n",
    "\n",
    "readout_config = {\n",
    "    \"readout_name\": \"PropagateSignalDown\",\n",
    "    \"num_cell_dimensions\": 3,\n",
    "    \"hidden_dim\": dim_hidden,\n",
    "    \"out_channels\": out_channels,\n",
    "    \"task_level\": \"graph\",\n",
    "    \"pooling_type\": \"sum\",\n",
    "}\n",
    "\n",
    "loss_config = {\n",
    "    \"dataset_loss\": \n",
    "        {\n",
    "            \"task\": \"classification\", \n",
    "            \"loss_type\": \"cross_entropy\"\n",
    "        }\n",
    "}\n",
    "\n",
    "evaluator_config = {\"task\": \"classification\",\n",
    "                    \"num_classes\": out_channels,\n",
    "                    \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n",
    "\n",
    "optimizer_config = {\"optimizer_id\": \"Adam\",\n",
    "                    \"parameters\":\n",
    "                        {\"lr\": 0.001,\"weight_decay\": 0.0005}\n",
    "                    }\n",
    "\n",
    "loader_config = OmegaConf.create(loader_config)\n",
    "transform_config = OmegaConf.create(transform_config)\n",
    "split_config = OmegaConf.create(split_config)\n",
    "readout_config = OmegaConf.create(readout_config)\n",
    "loss_config = OmegaConf.create(loss_config)\n",
    "evaluator_config = OmegaConf.create(evaluator_config)\n",
    "optimizer_config = OmegaConf.create(optimizer_config)"
   ]
  },
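  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The configs above are wrapped with `OmegaConf.create`, which turns plain dictionaries into `DictConfig` objects. As a small, self-contained illustration (the `demo_config` and `describe` names are hypothetical and unrelated to TopoBench), the sketch below shows attribute access, conversion back to plain containers, and `**` unpacking, which is how the configs are passed to classes later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "\n",
    "demo_config = OmegaConf.create(\n",
    "    {\"optimizer_id\": \"Adam\", \"parameters\": {\"lr\": 0.001, \"weight_decay\": 0.0005}}\n",
    ")\n",
    "\n",
    "# Nested values can be read with attribute access or item access.\n",
    "print(demo_config.parameters.lr)\n",
    "print(demo_config[\"optimizer_id\"])\n",
    "\n",
    "# Convert back to plain Python containers, e.g. for logging.\n",
    "print(OmegaConf.to_container(demo_config, resolve=True))\n",
    "\n",
    "\n",
    "def describe(optimizer_id, parameters):\n",
    "    \"\"\"Toy function standing in for a class constructor.\"\"\"\n",
    "    return f\"{optimizer_id} with lr={parameters.lr}\"\n",
    "\n",
    "\n",
    "# A DictConfig is a mapping, so it can be unpacked into keyword arguments.\n",
    "print(describe(**demo_config))"
   ]
  },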
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Step 2 :* Load the data. In this example we use the MUTAG dataset on the cell domain. To transform the dataset from the graph domain to the cell domain, we use the cycle lifting. The README of the [repository](https://github.com/geometric-intelligence/TopoBench?tab=readme-ov-file#rocket-liftings--transforms) has more information on the various liftings offered. \n",
    "\n",
    "Remark: a custom graph dataset not offered in TopoBench can also be used, provided it is formatted like a PyTorch Geometric graph dataset. It can then be passed to the `PreProcessor` class for lifting.\n",
    "\n",
    "Remark: the dataset needs to be re-loaded whenever the `neighborhoods` list is modified.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform parameters are the same, using existing data_dir: data/MUTAG/MUTAG/cycle_lifting/1611498484\n"
     ]
    }
   ],
   "source": [
    "graph_loader = TUDatasetLoader(loader_config)\n",
    "\n",
    "dataset, dataset_dir = graph_loader.load()\n",
    "\n",
    "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n",
    "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n",
    "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Step 4 :* Define the model. This is where we select our model to be a GCCN, and specify which GNN is used to build the GCCN. As with the choice of dataset, since the GNN is readily available (in this case, from PyTorch Geometric), all we need is to specify the config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_gccn_model = GAT(in_channels=dim_hidden, hidden_channels=dim_hidden, num_layers=1, out_channels=dim_hidden, heads=2, v2=False)\n",
    "\n",
    "backbone_config = {\n",
    "    \"GNN\": sub_gccn_model,\n",
    "    \"neighborhoods\": neighborhoods,\n",
    "    \"layers\": 2,\n",
    "    \"use_edge_attr\": False,\n",
    "    \"activation\": \"relu\"\n",
    "}\n",
    "\n",
    "backbone = TopoTune(**backbone_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that the model is defined we can create the TBModel, which takes care of implementing everything else that is needed to train the model. We will define a feature encoder and readout to book-end the GCCN (the `backbone` of the model) as well as instantiate a `loss`, `evaluator`, and `optimizer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n",
    "readout = PropagateSignalDown(**readout_config)\n",
    "\n",
    "loss = TBLoss(**loss_config)\n",
    "evaluator = TBEvaluator(**evaluator_config)\n",
    "optimizer = TBOptimizer(**optimizer_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can instantiate the TBModel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "wrapper = TuneWrapper(backbone=backbone, out_channels=out_channels, num_cell_dimensions=3, residual_connections=False) # task_level=\"graph\", pooling_type=\"sum\")\n",
    "model = TBModel(backbone=wrapper,\n",
    "                 backbone_wrapper=None,\n",
    "                 readout=readout,\n",
    "                 loss=loss,\n",
    "                 feature_encoder=feature_encoder,\n",
    "                 evaluator=evaluator,\n",
    "                 optimizer=optimizer,\n",
    "                 compile=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Step 5 :* Define the training scheme. In this case, we use the default training scheme and train the model with the `lightning` trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n",
      "\n",
      "  | Name            | Type                  | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | feature_encoder | AllCellFeatureEncoder | 1.3 K  | train\n",
      "1 | backbone        | TuneWrapper           | 3.6 K  | train\n",
      "2 | readout         | PropagateSignalDown   | 1.7 K  | train\n",
      "3 | val_acc_best    | MeanMetric            | 0      | train\n",
      "------------------------------------------------------------------\n",
      "6.6 K     Trainable params\n",
      "0         Non-trainable params\n",
      "6.6 K     Total params\n",
      "0.026     Total estimated model params size (MB)\n",
      "96        Modules in train mode\n",
      "0         Modules in eval mode\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=50` reached.\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False, log_every_n_steps=1)\n",
    "trainer.fit(model, datamodule)\n",
    "train_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Training metrics\n",
      " --------------------------\n",
      "train/accuracy:       0.8511\n",
      "train/precision:      0.8296\n",
      "train/recall:         0.8397\n",
      "val/loss:             0.4852\n",
      "val/accuracy:         0.7234\n",
      "val/precision:        0.7009\n",
      "val/recall:           0.7260\n",
      "train/loss:           0.2851\n"
     ]
    }
   ],
   "source": [
    "print('      Training metrics\\n', '-'*26)\n",
    "for key in train_metrics:\n",
    "    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/papillon/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test/accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7659574747085571     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.3583049476146698     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test/precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7471264600753784     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test/recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7529411911964417     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test/accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7659574747085571    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.3583049476146698    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test/precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7471264600753784    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test/recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7529411911964417    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.test(model, datamodule)\n",
    "test_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Testing metrics\n",
      " -------------------------\n",
      "test/loss:           0.3583\n",
      "test/accuracy:       0.7660\n",
      "test/precision:      0.7471\n",
      "test/recall:         0.7529\n"
     ]
    }
   ],
   "source": [
    "print('      Testing metrics\\n', '-'*25)\n",
    "for key in test_metrics:\n",
    "    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Use Case B:** GCCN with custom GNN and dataset available in TopoBench <a class=\"anchor\" id=\"sec3\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this use case, we repeat the same process as in [Use Case A](#sec2), except that the sub-model we use to build the GCCN is a custom neural network, which may be a GNN or any other architecture. For our purposes, we define a toy model below as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(pl.LightningModule):\n",
    "    def __init__(self, hidden_channels, out_channels):\n",
    "        super().__init__()\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.linear_0 = torch.nn.Linear(hidden_channels, out_channels)\n",
    "        self.linear_1 = torch.nn.Linear(hidden_channels, out_channels)\n",
    "        self.linear_2 = torch.nn.Linear(hidden_channels, out_channels)\n",
    "\n",
    "    def forward(self, batch):\n",
    "        x_0 = batch.x_0\n",
    "        x_1 = batch.x_1\n",
    "        x_2 = batch.x_2\n",
    "        x_0 = self.linear_0(x_0)\n",
    "        x_0 = torch.relu(x_0)\n",
    "        x_1 = self.linear_1(x_1)\n",
    "        x_1 = torch.relu(x_1)\n",
    "        x_2 = self.linear_2(x_2)\n",
    "        x_2 = torch.relu(x_2)\n",
    "        \n",
    "        model_out = {\"labels\": batch.y, \"batch_0\": batch.batch_0}\n",
    "        model_out[\"x_0\"] = x_0\n",
    "        model_out[\"x_1\"] = x_1\n",
    "        model_out[\"x_2\"] = x_2\n",
    "        return model_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can build a GCCN with this custom model. Note that we increase the number of GCCN layers (and hence the number of sub-models) here."
   ]
  },
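  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_sub_gccn_model = MyModel(dim_hidden, out_channels)\n",
    "\n",
    "backbone_config = {\n",
    "    \"GNN\": sub_gccn_model,  # note: this is the GAT defined in Use Case A; custom_sub_gccn_model is instantiated above but not passed here\n",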
    "    \"neighborhoods\": neighborhoods,\n",
    "    \"layers\": 4,\n",
    "    \"use_edge_attr\": False,\n",
    "    \"activation\": \"relu\"\n",
    "}\n",
    "\n",
    "backbone = TopoTune(**backbone_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can train this custom model as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "readout = PropagateSignalDown(**readout_config)\n",
    "loss = TBLoss(**loss_config)\n",
    "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n",
    "\n",
    "evaluator = TBEvaluator(**evaluator_config)\n",
    "optimizer = TBOptimizer(**optimizer_config)\n",
    "\n",
    "wrapper = TuneWrapper(backbone=backbone, out_channels=out_channels, num_cell_dimensions=3, residual_connections=False)\n",
    "model = TBModel(backbone=wrapper,\n",
    "                 backbone_wrapper=None,\n",
    "                 readout=readout,\n",
    "                 loss=loss,\n",
    "                 feature_encoder=feature_encoder,\n",
    "                 evaluator=evaluator,\n",
    "                 optimizer=optimizer,\n",
    "                 compile=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "\n",
      "  | Name            | Type                  | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | feature_encoder | AllCellFeatureEncoder | 1.3 K  | train\n",
      "1 | backbone        | TuneWrapper           | 7.1 K  | train\n",
      "2 | readout         | PropagateSignalDown   | 1.7 K  | train\n",
      "3 | val_acc_best    | MeanMetric            | 0      | train\n",
      "------------------------------------------------------------------\n",
      "10.1 K    Trainable params\n",
      "0         Non-trainable params\n",
      "10.1 K    Total params\n",
      "0.041     Total estimated model params size (MB)\n",
      "158       Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=10` reached.\n"
     ]
    }
   ],
   "source": [
    "# Increase the number of epochs to get better results\n",
    "trainer = pl.Trainer(max_epochs=10, accelerator=\"cpu\", enable_progress_bar=False, log_every_n_steps=1)\n",
    "\n",
    "trainer.fit(model, datamodule)\n",
    "train_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Training metrics\n",
      " --------------------------\n",
      "train/accuracy:       0.8617\n",
      "train/precision:      0.8407\n",
      "train/recall:         0.8559\n",
      "val/loss:             0.4541\n",
      "val/accuracy:         0.7447\n",
      "val/precision:        0.7180\n",
      "val/recall:           0.7417\n",
      "train/loss:           0.2961\n"
     ]
    }
   ],
   "source": [
    "print('      Training metrics\\n', '-'*26)\n",
    "for key in train_metrics:\n",
    "    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test/accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7446808218955994     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.4429433345794678     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test/precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7235294580459595     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test/recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7235294580459595     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test/accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7446808218955994    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.4429433345794678    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test/precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7235294580459595    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test/recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7235294580459595    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.test(model, datamodule)\n",
    "test_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Testing metrics\n",
      " -------------------------\n",
      "test/loss:           0.4429\n",
      "test/accuracy:       0.7447\n",
      "test/precision:      0.7235\n",
      "test/recall:         0.7235\n"
     ]
    }
   ],
   "source": [
    "print('      Testing metrics\\n', '-'*25)\n",
    "for key in test_metrics:\n",
    "    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **Use Case C:** Running large scale GCCN sweeps with TopoBench <a class=\"anchor\" id=\"sec4\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, we consider command line operations that allow for rapidly defining and testing many possible GCCN architectures. To implement and train a GCCN from the command line, run the following with the desired choice of dataset, lifting domain (e.g. `cell`, `simplicial`), PyTorch Geometric backbone model (e.g. `GCN`, `GIN`, `GAT`, `GraphSAGE`) and parameters (e.g. `model.backbone.GNN.num_layers=2`), neighborhood structure (routes), and other hyperparameters. To use a single augmented Hasse graph expansion, use `model={domain}/topotune_onehasse` instead of `model={domain}/topotune`.\n",
    "\n",
    "Here is an example command for a single model:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "! python -m topobench \\\n",
    "    dataset=graph/MUTAG \\\n",
    "    dataset.split_params.data_seed=1 \\\n",
    "    model=cell/topotune \\\n",
    "    model.tune_gnn=GCN \\\n",
    "    model.backbone.GNN.num_layers=2 \\\n",
    "    model.backbone.neighborhoods=\\[1-up_laplacian-0,1-down_incidence-2\\] \\\n",
    "    model.backbone.layers=4 \\\n",
    "    model.feature_encoder.out_channels=32 \\\n",
    "    model.feature_encoder.proj_dropout=0.3 \\\n",
    "    model.readout.readout_name=PropagateSignalDown \\\n",
    "    logger.wandb.project=TopoTune_Tutorial \\\n",
    "    trainer.max_epochs=10 \\\n",
    "    callbacks.early_stopping.patience=50\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To extend this process to many GCCNs, it is sufficient to pass a comma-separated list of values for any argument and to add the `--multirun` flag. This is a shortcut for running every possible combination of the specified parameters in a single command.\n",
    "\n",
    "We provide an example of such a \"sweep\" command below:\n",
    "\n",
    "```\n",
    "! python -m topobench \\\n",
    "    dataset=graph/cocitation_cora \\\n",
    "    model=cell/topotune,cell/topotune_onehasse \\\n",
    "    model.feature_encoder.out_channels=32 \\\n",
    "    model.tune_gnn=GCN,GIN,GAT,GraphSAGE \\\n",
    "    model.backbone.GNN.num_layers=1,2 \\\n",
    "    model.backbone.neighborhoods=\\[1-up_laplacian-0,1-down_laplacian-1\\],\\[1-up_laplacian-0,1-down_incidence-2\\] \\\n",
    "    model.backbone.layers=2,4 \\\n",
    "    model.feature_encoder.proj_dropout=0.3 \\\n",
    "    dataset.split_params.data_seed=1,3,5,7,9 \\\n",
    "    model.readout.readout_name=PropagateSignalDown \\\n",
    "    logger.wandb.project=TopoTune_Tutorial \\\n",
    "    trainer.max_epochs=1000 \\\n",
    "    trainer.min_epochs=50 \\\n",
    "    trainer.devices=\\[1\\] \\\n",
    "    trainer.check_val_every_n_epoch=1 \\\n",
    "    callbacks.early_stopping.patience=50 \\\n",
    "    tags=\"[FirstExperiments]\" \\\n",
    "    --multirun\n",
    "```"
   ]
  },
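  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching a sweep, it can help to estimate how many runs it will spawn. The following is a minimal sketch, assuming the behavior described above in which `--multirun` launches one run per combination of the comma-separated values. The `sweep` dictionary is a hypothetical helper (not part of TopoBench) that simply copies the multi-valued overrides from the sweep command above.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "# Multi-valued overrides copied from the sweep command above (hypothetical helper,\n",
    "# not a TopoBench API). Each list holds the comma-separated values of one override.\n",
    "sweep = {\n",
    "    'model': ['cell/topotune', 'cell/topotune_onehasse'],\n",
    "    'model.tune_gnn': ['GCN', 'GIN', 'GAT', 'GraphSAGE'],\n",
    "    'model.backbone.GNN.num_layers': [1, 2],\n",
    "    'model.backbone.neighborhoods': [\n",
    "        '[1-up_laplacian-0,1-down_laplacian-1]',\n",
    "        '[1-up_laplacian-0,1-down_incidence-2]',\n",
    "    ],\n",
    "    'model.backbone.layers': [2, 4],\n",
    "    'dataset.split_params.data_seed': [1, 3, 5, 7, 9],\n",
    "}\n",
    "\n",
    "# One run per element of the Cartesian product of all option lists.\n",
    "runs = [dict(zip(sweep, combo)) for combo in product(*sweep.values())]\n",
    "\n",
    "print(len(runs))  # 2 * 4 * 2 * 2 * 2 * 5 = 320\n",
    "```\n",
    "\n",
    "Overrides that take a single value (such as `trainer.max_epochs=1000`) do not multiply the count, so they are left out of the dictionary."
   ]
  },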
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using backbone models from any package\n",
    "By default, backbone models are imported from `torch_geometric.nn.models`. To import and specify a backbone model from any other package, such as `torch.nn.Transformer` or `dgl.nn.GATConv`, it is sufficient to 1) make sure the package is installed and 2) pass the following overrides on the command line:\n",
    "\n",
    "```\n",
    "model.tune_gnn={backbone_model}\n",
    "model.backbone.GNN._target_={package}.{backbone_model}\n",
    "```\n",
    "\n",
    "### Reproducing experiments\n",
    "\n",
    "We provide scripts to reproduce experiments on a broad class of GCCNs in [`scripts/topotune`](scripts/topotune) and reproduce iterations of existing neural networks in [`scripts/topotune/existing_models`](scripts/topotune/existing_models), as previously reported in the [TopoTune paper](https://arxiv.org/pdf/2410.06530)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tb",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
